-
Notifications
You must be signed in to change notification settings - Fork 373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
polish(lwq): polish VAE #404
Conversation
ding/model/template/vae.py
Outdated
@@ -35,6 +41,12 @@ def loss_function(self, *inputs: Any, **kwargs) -> Tensor: | |||
|
|||
|
|||
class VanillaVAE(BaseVAE): | |||
""" | |||
Overview: | |||
Implementation of Vanilla variational autoencoder. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implementation of Vanilla variational autoencoder for action reconstruction
ding/model/template/vae.py
Outdated
@@ -74,17 +84,23 @@ def __init__( | |||
|
|||
def encode(self, input) -> Dict[str, Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add python typing for input
ding/model/template/vae.py
Outdated
@@ -133,15 +147,14 @@ def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]: | |||
- outputs (:obj:`Dict`): DQN forward outputs, such as q_value. | |||
ReturnsKeys: | |||
- reconstruction_action (:obj:`torch.Tensor`): reconstruction_action. | |||
- predition_residual (:obj:`torch.Tensor`): predition_residual. | |||
- predition_residual (:obj:`torch.Tensor`): prediction_residual. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add more comments
ding/model/template/vae.py
Outdated
@@ -166,6 +179,21 @@ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: | |||
return eps * std + mu | |||
|
|||
def forward(self, input: Tensor, **kwargs) -> dict: | |||
""" | |||
Overview: | |||
encode the input, reparameterize `mu` and `log_var`, decode `obs_encoding`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Encode
Codecov Report
@@ Coverage Diff @@
## main #404 +/- ##
==========================================
+ Coverage 85.61% 85.83% +0.22%
==========================================
Files 524 524
Lines 41108 41571 +463
==========================================
+ Hits 35195 35683 +488
+ Misses 5913 5888 -25
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
ba3db61
to
c09778b
Compare
* polish(lwq): polish VAE * remove base VAE class
Description
Related Issue
TODO
Check List